// -*- mode:c++ -*-

// Copyright (c) 2022 PLCT Lab
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


def format VConfOp(code, write_code, declare_class, branch_class, *flags) {{
    iop = InstObjParams(
        name,
        Name,
        'VConfOp',
        {
            'code': code,
            'write_code': write_code,
        },
        flags
    )
    declareTemplate = eval(declare_class)
    branchTargetTemplate = eval(branch_class)

    header_output = declareTemplate.subst(iop)
    decoder_output = VConfConstructor.subst(iop)
    decode_block = VConfDecodeBlock.subst(iop)
    exec_output = VConfExecute.subst(iop) + branchTargetTemplate.subst(iop)
}};

def template VSetVlDeclare {{
    /**
     * Static instruction class for "%(mnemonic)s".
     *
     * This flavour overrides the branchTarget() overload that takes a
     * ThreadContext.
     */
    class %(class_name)s : public %(base_class)s
    {
      private:
        %(reg_idx_arr_decl)s;

        /// Resulting vtype when reqVtype is requested over oldVtype.
        VTYPE getNewVtype(VTYPE oldVtype, VTYPE reqVtype) const;

        /// Resulting vector length for the given request and vlmax.
        uint32_t getNewVL(uint32_t currentVl, uint32_t reqVl,
                          uint32_t vlmax, uint64_t rdBits,
                          uint64_t rs1Bits) const;

      public:
        // Keep the inherited overloads visible next to the override below.
        using StaticInst::branchTarget;
        using %(base_class)s::generateDisassembly;

        /// Constructor.
        %(class_name)s(ExtMachInst machInst, uint32_t elen, uint32_t vlen);

        Fault execute(ExecContext *xc,
                      trace::InstRecord *traceData) const override;

        std::unique_ptr<PCStateBase> branchTarget(
                ThreadContext *tc) const override;
    };
}};

def template VSetiVliDeclare {{
    /**
     * Static instruction class for "%(mnemonic)s".
     *
     * This flavour overrides the branchTarget() overload that takes the
     * PC state of the instruction itself.
     */
    class %(class_name)s : public %(base_class)s
    {
      private:
        %(reg_idx_arr_decl)s;

        /// Resulting vtype when reqVtype is requested over oldVtype.
        VTYPE getNewVtype(VTYPE oldVtype, VTYPE reqVtype) const;

        /// Resulting vector length for the given request and vlmax.
        uint32_t getNewVL(uint32_t currentVl, uint32_t reqVl,
                          uint32_t vlmax, uint64_t rdBits,
                          uint64_t rs1Bits) const;

      public:
        // Keep the inherited overloads visible next to the override below.
        using StaticInst::branchTarget;
        using %(base_class)s::generateDisassembly;

        /// Constructor.
        %(class_name)s(ExtMachInst machInst, uint32_t elen, uint32_t vlen);

        Fault execute(ExecContext *xc,
                      trace::InstRecord *traceData) const override;

        std::unique_ptr<PCStateBase> branchTarget(
                const PCStateBase &branch_pc) const override;
    };
}};

def template VConfConstructor {{
    // Constructor for "%(mnemonic)s": hands the mnemonic, machine
    // instruction, elen, vlen and op class to the base class, then runs the
    // generated register-index and constructor code.
    %(class_name)s::%(class_name)s(
        ExtMachInst _machInst, uint32_t _elen, uint32_t _vlen)
      : %(base_class)s("%(mnemonic)s", _machInst, _elen, _vlen, %(op_class)s)
    {
        %(set_reg_idx_arr)s;
        %(constructor)s;
    }
}};

// Decode action used by the VConfOp format: allocates the instruction object
// for the decoded class and returns it.
// NOTE(review): machInst, elen and vlen are not defined in this template;
// presumably they are in scope in the generated decode function -- confirm.
def template VConfDecodeBlock {{
    return new %(class_name)s(machInst, elen, vlen);
}};

def template VConfExecute {{
    /**
     * Work out the vtype that takes effect when reqVtype is requested while
     * oldVtype is the current value.
     *
     * An unchanged request returns oldVtype untouched. Otherwise the request
     * is accepted as-is unless it is illegal, in which case the result is
     * the canonical illegal value: every field zero except vill.
     *
     * A request is illegal when any of the following holds:
     *  - the fractional LMUL lies outside [1/8, 8];
     *  - SEW exceeds min(LMUL, 1) * elen;
     *  - any bit above the low eight is set. This deliberately includes the
     *    vill bit (bit 63): a request that already carries vill must not
     *    keep its other fields, because vtype is required to be all-zero
     *    apart from vill whenever vill is set.
     *
     * @param oldVtype vtype currently in effect.
     * @param reqVtype vtype requested by the instruction.
     * @return the vtype to install.
     */
    VTYPE
    %(class_name)s::getNewVtype(
        VTYPE oldVtype, VTYPE reqVtype) const
    {
        VTYPE newVtype = oldVtype;
        if (oldVtype != reqVtype) {
            newVtype = reqVtype;

            const float vflmul = getVflmul(newVtype.vlmul);
            const uint32_t sew = getSew(newVtype.vsew);

            const bool lmulInRange = vflmul >= 0.125 && vflmul <= 8;
            const bool sewTooWide =
                static_cast<float>(sew) >
                    std::min(vflmul, 1.0f) * static_cast<float>(elen);
            // Bits 63:8 cover the reserved field and vill itself.
            const bool upperBitsSet = bits(reqVtype, 63, 8) != 0;

            if (!lmulInRange || sewTooWide || upperBitsSet) {
                newVtype = 0;
                newVtype.vill = 1;
            }
        }
        return newVtype;
    }

    /**
     * Select the new vector length.
     *
     *  - vlmax == 0                    : 0
     *  - rs1Bits != 0                  : reqVl clamped to vlmax
     *  - rs1Bits == 0 and rdBits != 0  : vlmax
     *  - rs1Bits == 0 and rdBits == 0  : currentVl clamped to vlmax
     *
     * @param currentVl vector length currently in effect.
     * @param reqVl     requested vector length.
     * @param vlmax     maximum vector length for the new vtype.
     * @param rdBits    destination register field of the instruction.
     * @param rs1Bits   first source register field of the instruction.
     * @return the vector length to install.
     */
    uint32_t
    %(class_name)s::getNewVL(uint32_t currentVl, uint32_t reqVl,
        uint32_t vlmax, uint64_t rdBits, uint64_t rs1Bits) const
    {
        if (vlmax == 0) {
            return 0;
        }
        if (rs1Bits != 0) {
            return std::min(reqVl, vlmax);
        }
        if (rdBits != 0) {
            return vlmax;
        }
        return std::min(currentVl, vlmax);
    }

    /**
     * Execute the configuration instruction.
     *
     * Returns an IllegalInstFault when the vector extension is disabled in
     * MISA or the vector unit status is OFF; otherwise computes the new
     * vtype and vector length, runs the write-back snippet and returns
     * NoFault.
     */
    Fault
    %(class_name)s::execute(ExecContext *xc,
        trace::InstRecord *traceData) const
    {
        auto tc = xc->tcBase();
        MISA misa = xc->readMiscReg(MISCREG_ISA);
        STATUS status = xc->readMiscReg(MISCREG_STATUS);
        if (!misa.rvv || status.vs == VPUStatus::OFF) {
            return std::make_shared<IllegalInstFault>(
                "RVV is disabled or VPU is off", machInst);
        }

        // The instruction-specific snippet is expected to define
        // requested_vtype, requested_vl, current_vl, rd_bits, rs1_bits and
        // vlmax; they are used below without being declared here.
        %(op_decl)s;
        %(op_rd)s;
        %(code)s;

        // vstart is reset to zero on every execution.
        tc->setMiscReg(MISCREG_VSTART, 0);

        // new_vtype and new_vl are left in scope for the write-back
        // snippet that follows.
        VTYPE new_vtype = getNewVtype(Vtype, requested_vtype);
        vlmax = new_vtype.vill ? 0 : getVlmax(new_vtype, vlen);
        uint32_t new_vl = getNewVL(
            current_vl, requested_vl, vlmax, rd_bits, rs1_bits);

        %(write_code)s;

        %(op_wb)s;
        return NoFault;
    }
}};

def template VSetiVliBranchTarget {{
    // Next PC state for the immediate-only form: both the requested vector
    // length and the requested vtype come from instruction immediates, so
    // the result depends only on the PC state passed in.
    std::unique_ptr<PCStateBase>
    %(class_name)s::branchTarget(const PCStateBase &branch_pc) const
    {
        const auto &cur_pc = branch_pc.as<RiscvISA::PCState>();

        uint64_t vtype_req = zimm10;
        uint64_t avl_req = uimm;
        uint64_t rd_field = machInst.rd;
        // All-ones, i.e. non-zero: always selects the "clamp the requested
        // length" case of getNewVL().
        uint64_t rs1_field = static_cast<uint64_t>(-1);

        VTYPE next_vtype = getNewVtype(cur_pc.vtype(), vtype_req);

        uint32_t vlmax = 0;
        if (!next_vtype.vill) {
            vlmax = getVlmax(next_vtype, vlen);
        }

        uint32_t next_vl = getNewVL(
            cur_pc.vl(), avl_req, vlmax, rd_field, rs1_field);

        // Work on a copy of the incoming state: advance it by four bytes
        // and record the new vector configuration.
        std::unique_ptr<PCState> next_pc(
            dynamic_cast<PCState*>(cur_pc.clone()));
        next_pc->set(rvSext(next_pc->pc() + 4));
        next_pc->vtype(next_vtype);
        next_pc->vl(next_vl);
        return next_pc;
    }
}};

def template VSetVliBranchTarget {{
    // Next PC state for the form whose requested vector length is read from
    // the first source register and whose requested vtype is an immediate.
    std::unique_ptr<PCStateBase>
    %(class_name)s::branchTarget(ThreadContext *tc) const
    {
        // Own the copy of the thread's PC state from the start and update
        // it in place through a single typed reference.
        std::unique_ptr<PCStateBase> next_pc{tc->pcState().clone()};
        auto &rv_pc = next_pc->as<PCState>();

        uint64_t rd_field = machInst.rd;
        uint64_t rs1_field = machInst.rs1;
        uint64_t avl_req = tc->getReg(srcRegIdx(0));
        uint64_t vtype_req = zimm11;

        VTYPE next_vtype = getNewVtype(rv_pc.vtype(), vtype_req);

        uint32_t vlmax = 0;
        if (!next_vtype.vill) {
            vlmax = getVlmax(next_vtype, vlen);
        }

        uint32_t next_vl = getNewVL(
            rv_pc.vl(), avl_req, vlmax, rd_field, rs1_field);

        rv_pc.vtype(next_vtype);
        rv_pc.vl(next_vl);
        return next_pc;
    }
}};

def template VSetVlBranchTarget {{
    // Next PC state for the all-register form: the requested vector length
    // is read from the first source register and the requested vtype from
    // the second.
    std::unique_ptr<PCStateBase>
    %(class_name)s::branchTarget(ThreadContext *tc) const
    {
        // Own the copy of the thread's PC state from the start and update
        // it in place through a single typed reference.
        std::unique_ptr<PCStateBase> next_pc{tc->pcState().clone()};
        auto &rv_pc = next_pc->as<PCState>();

        uint64_t rd_field = machInst.rd;
        uint64_t rs1_field = machInst.rs1;
        uint64_t avl_req = tc->getReg(srcRegIdx(0));
        uint64_t vtype_req = tc->getReg(srcRegIdx(1));

        VTYPE next_vtype = getNewVtype(rv_pc.vtype(), vtype_req);

        uint32_t vlmax = 0;
        if (!next_vtype.vill) {
            vlmax = getVlmax(next_vtype, vlen);
        }

        uint32_t next_vl = getNewVL(
            rv_pc.vl(), avl_req, vlmax, rd_field, rs1_field);

        rv_pc.vtype(next_vtype);
        rv_pc.vl(next_vl);
        return next_pc;
    }
}};
